import torch

# Demonstrates broadcasting and ``Tensor.expand``.
#
# Adding a (4, 1) column to a (1, 5) row broadcasts to a (4, 5) matrix with
# mat[i, j] == 10 * i + j.  The factory ``torch.arange`` replaces the legacy
# ``torch.Tensor(...)`` constructor; ``dtype=torch.float32`` keeps the float32
# result the legacy constructor produced.
col = torch.arange(0, 40, 10, dtype=torch.float32).unsqueeze(1)  # (4, 1)
row = torch.arange(5, dtype=torch.float32).unsqueeze(0)  # (1, 5)
mat = col + row
print(mat.shape)
print(mat)

# Keep a handle on the 2-D matrix before ``mat`` is rebound below.
# ``unsqueeze`` and ``expand`` are out-of-place (they return new views and
# never modify their input), so no ``clone()`` is needed to protect it.
# NOTE: the expanded views created below share this tensor's storage, so do
# not write to ``mat_`` in place while they are in use.
mat_ = mat

# Insert a trailing size-1 dim -> (4, 5, 1), then broadcast it to size 2.
# ``expand`` allocates no new memory: it returns a view whose stride along the
# expanded dim is 0, and -1 keeps an existing dim's size unchanged.
mat = mat.unsqueeze(2).expand(-1, -1, 2)
print(mat.shape)
print(mat)

# Same idea with two trailing dims: (4, 5) -> (4, 5, 1, 1) -> (4, 5, 2, 3).
mat = mat_.unsqueeze(2).unsqueeze(3).expand(-1, -1, 2, 3)
print(mat.shape)
print(mat)
